import logging
import os
import numpy as np
import torch
from config import Config
from evaluate import Evaluator
from loader import load_data
from model import Torchmodel, choose_optimizer
import matplotlib.pyplot as plt
from tqdm import tqdm

# Root logging: INFO and above, each record prefixed with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-level logger; main() also hands it to the Evaluator.
logger = logging.getLogger(__name__)


def main(config):
    """Train the model, evaluate it after every epoch, then save the final
    weights and a loss/pearson plot.

    Args:
        config: mapping that must provide at least "model_path",
            "train_data_path", "epoch" and "pretrain_model_path"; it is also
            passed through to load_data, Torchmodel, choose_optimizer and
            Evaluator, which may read further keys.

    Side effects:
        Creates config["model_path"] (including parents) if missing and writes
        "<model_name>_epoch_<N>.pth" (the final state dict) into it, where N is
        the last epoch run (0 if config["epoch"] is 0). Also writes the plot
        produced by plt_loss.
    """
    # makedirs creates missing parents and tolerates an existing directory.
    os.makedirs(config["model_path"], exist_ok=True)
    train_data = load_data(config["train_data_path"], config)
    model = Torchmodel(config)
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用，迁移模型至gpu")
        model = model.cuda()
    optimizer = choose_optimizer(config, model)
    evaluator = Evaluator(config, model, logger)

    epoch_losses = []    # mean training loss of each epoch
    epoch_pearsons = []  # value returned by evaluator.eval for each epoch
    # Bound before the loop so the checkpoint name below is defined even when
    # config["epoch"] is 0 and the loop body never runs.
    epoch = 0
    for epoch in range(1, config["epoch"] + 1):  # epochs are numbered from 1
        model.train()
        logger.info("epoch %d begin", epoch)
        train_loss = []
        for batch_data in tqdm(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id1, input_id2, labels = batch_data
            # The forward pass returns the loss tensor that is backpropagated.
            loss = model(input_id1, input_id2, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        mean_loss = np.mean(train_loss)
        logger.info("epoch average loss: %f", mean_loss)
        epoch_losses.append(mean_loss)
        epoch_pearsons.append(evaluator.eval(epoch))

    # basename(normpath(...)) keeps the last path component even when the
    # configured path ends with a separator (split('/')[-1] gave "").
    model_name = os.path.basename(os.path.normpath(config["pretrain_model_path"]))
    model_path = os.path.join(config["model_path"], "%s_epoch_%d.pth" % (model_name, epoch))
    torch.save(model.state_dict(), model_path)
    plt_loss(epoch_losses, epoch_pearsons, config)


def plt_loss(loss, pearson, config):
    """Plot per-epoch loss and pearson values on twin y-axes and save as JPEG.

    Args:
        loss: per-epoch values drawn in green against the left y-axis.
        pearson: per-epoch values drawn in red against the right y-axis.
        config: mapping providing 'model_path' (output directory),
            'pretrain_model_path' and 'epoch'. The latter two build the file
            name "<model_name>_epoch<epoch>_loss.jpg".
    """
    # Non-interactive backend: render straight to a file, no display needed.
    # Background: https://cloud.tencent.com/developer/article/1559466
    plt.switch_backend('Agg')
    fig = plt.figure()
    try:
        ax = fig.add_subplot()
        # Entries are per epoch and main() numbers epochs from 1, so plot
        # against 1..N rather than the default 0-based index.
        line1 = ax.plot(range(1, len(loss) + 1), loss, color='green', label='loss', marker='o')
        ax.set_ylabel('loss')
        ax.set_xlabel('epoch')

        # Second y-axis sharing the same epoch x-axis.
        ax2 = ax.twinx()
        line2 = ax2.plot(range(1, len(pearson) + 1), pearson, color='red', label='pearson', marker='o')
        ax2.set_ylabel('pearson')

        ax.set_title('pearson and loss')
        # Merge the handles of both axes into a single legend.
        lines = line1 + line2
        ax.legend(lines, [line.get_label() for line in lines], loc='upper right')

        # Robust to a trailing separator in the configured path.
        model_name = os.path.basename(os.path.normpath(config['pretrain_model_path']))
        fig.savefig(os.path.join(config['model_path'], "%s_epoch%s_loss.jpg" % (model_name, config['epoch'])))
    finally:
        # pyplot keeps every figure alive until closed; release it explicitly.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: train using the project-wide Config mapping.
    main(config=Config)
